package com.learn;

import java.util.Scanner;

/**
 * Demo program that fills an {@code n x n} matrix with the numbers {@code 1..n*n} in zigzag
 * order (walking anti-diagonals, alternating direction) and prints it row by row.
 */
public class TestMain {

    /** Matrix size used when no size is supplied on the command line. */
    private static final int DEFAULT_SIZE = 4;

    /**
     * Builds the zigzag matrix and prints it to standard output, one row per line, each value
     * followed by a single space.
     *
     * @param args optional; if present, {@code args[0]} is the matrix size (default 4)
     * @throws NumberFormatException if {@code args[0]} is not a valid integer
     * @throws IllegalArgumentException if the requested size is negative
     */
    public static void main(String[] args) {
        int n = (args != null && args.length > 0) ? Integer.parseInt(args[0].trim()) : DEFAULT_SIZE;
        System.out.print(render(zigzag(n)));
    }

    /**
     * Builds an {@code n x n} matrix filled with {@code 1..n*n} along anti-diagonals. Cells on an
     * even diagonal ({@code row + col} even) are visited moving up-right; cells on an odd diagonal
     * are visited moving down-left.
     *
     * @param n matrix size; {@code 0} yields an empty matrix
     * @return the filled matrix, indexed {@code [row][col]}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    static int[][] zigzag(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        int[][] grid = new int[n][n];
        int row = 0;
        int col = 0;
        int total = n * n;
        // After the final cell is written, row/col may step outside the matrix; the loop ends
        // before that position is ever used.
        for (int num = 1; num <= total; num++) {
            grid[row][col] = num;
            if ((row + col) % 2 == 0) {
                // Even diagonal: moving up-right. The right edge is checked before the top edge
                // so that the top-right corner steps down rather than off the matrix.
                if (col == n - 1) {
                    row++;
                } else if (row == 0) {
                    col++;
                } else {
                    row--;
                    col++;
                }
            } else {
                // Odd diagonal: moving down-left. The bottom edge is checked before the left edge
                // so that the bottom-left corner steps right rather than off the matrix.
                if (row == n - 1) {
                    col++;
                } else if (col == 0) {
                    row++;
                } else {
                    // Normal case: row + 1, column - 1.
                    row++;
                    col--;
                }
            }
        }
        return grid;
    }

    /**
     * Formats a matrix as text: every value is followed by a single space, and every row ends
     * with {@link System#lineSeparator()} (the same terminator {@code println} uses).
     *
     * @param grid the matrix to format
     * @return the formatted text
     */
    static String render(int[][] grid) {
        StringBuilder sb = new StringBuilder();
        for (int[] line : grid) {
            for (int value : line) {
                sb.append(value).append(' ');
            }
            sb.append(System.lineSeparator());
        }
        return sb.toString();
    }
}
